/*
 * Copyright (c) Hisilicon Technologies Co., Ltd. 2023-2023. All rights reserved.
 * Description: quantize sever
 */

#include "omg/quantize_optimizer/quantize_saver.h"

#include <algorithm>

#include "graph/types.h"
#include "graph/op/all_ops.h"

#include "infra/base/assertion.h"
#include "infra/base/securestl.h"

#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/op/internal_defs.h"

#include "omg/quantize_optimizer/general_ir_quantize_saver.h"
#include "omg/quantize_optimizer/legacy_quantize_saver.h"
#include "omg/quantize_optimizer/quantize_util.h"
#include "omg/params.h"

using namespace std;

namespace hiai {
namespace {
// A (featuremap dtype, weight dtype) pair identifying one supported
// quantization mode (e.g. U8S8, S4S4) for the legacy compatibility check.
struct QuantDataType {
    ge::DataType inputDataType;  // featuremap (activation) input data type
    ge::DataType weightDataType; // filter/weight data type
};

#ifndef AI_SUPPORT_GENERAL_QUANTIZE_IR
// Check whether the (featuremap dtype, weight dtype) combination is one the
// legacy fused-quantization path supports.
bool IsCompatibleQuantType(ge::DataType inputDataType, ge::DataType weightDataType)
{
    static const vector<QuantDataType> compatibleQuantTypes = {
        { ge::DT_FLOAT, ge::DT_INT8 }, // one-side quantization
        { ge::DT_FLOAT16, ge::DT_INT8 }, // one-side quantization
        { ge::DT_UINT8, ge::DT_INT8 }, // U8S8 quantization
        { ge::DT_INT4, ge::DT_INT4 }, // S4S4 quantization
        { ge::DT_UINT8, ge::DT_2BIT } // U8S2 quantization
    };

    return std::any_of(compatibleQuantTypes.cbegin(), compatibleQuantTypes.cend(),
        [inputDataType, weightDataType](const QuantDataType& candidate) {
            return candidate.inputDataType == inputDataType && candidate.weightDataType == weightDataType;
        });
}

/*
 * Check whether the featuremap input (anchor 'inputIndex') uses per-channel
 * quantization, i.e. the scale of its last inserted operator holds more than
 * one element.
 */
bool IsFeaturemapPerChannelQuant(const vector<QuantizeParams>& inputQuantParams, uint32_t inputIndex)
{
    // const&: the original iterated by value, copying every QuantizeParams
    // (including its scale/offset vectors) per iteration.
    for (const auto& inputQuantParam : inputQuantParams) {
        if (inputQuantParam.index != inputIndex) {
            continue;
        }
        // Guard: the original indexed operatorParams[size - 1], which underflows
        // (out-of-bounds access) when operatorParams is empty.
        if (inputQuantParam.operatorParams.empty()) {
            continue;
        }
        if (inputQuantParam.operatorParams.back().scale.size() > 1) {
            return true;
        }
    }

    return false;
}
#endif

// Decide whether a V1-configured quantized node must take the legacy
// (fused-attribute) save path instead of the general quantize-IR path.
bool CheckNeedCompatibleQuantV1(const ge::Node& node, const QuantizeConfig& quantizeV1Config)
{
#ifdef AI_SUPPORT_GENERAL_QUANTIZE_IR
    (void)node;
    (void)quantizeV1Config;
    return false;
#else
    const auto& opType = node.ROLE(NodeSpec).Type();
    // ConvTranspose and GemmD are supported since the V210 platform: they
    // always use the quantize-IR small ops.
    if (opType == hiai::op::ConvTranspose::TYPE || opType == hiai::op::GemmD::TYPE) {
        return false;
    }
    if (!QuantizeUtil::IsSupportQuantOpType(opType)) {
        return false;
    }
    return IsCompatibleQuantType(quantizeV1Config.inputDataType, quantizeV1Config.weightDataType);
#endif
}

// Return the anchor indices of the featuremap input and the weight of 'node'.
// Most quantizable ops use (input = 0, weight = 1); per its IR definition,
// ConvTranspose has the featuremap at anchor 2 and the filter at anchor 1.
void GetInputAndWeightAnchorIndex(const ge::Node& node, uint32_t& inputIndex, uint32_t& weightIndex)
{
    const bool isConvTranspose = (node.ROLE(NodeSpec).Type() == hiai::op::ConvTranspose::TYPE);
    inputIndex = isConvTranspose ? 2 : 0;
    weightIndex = 1;
}

/*
 * Decide whether a V2-configured quantized node must take the legacy
 * (fused-attribute) save path. Returns true only for the simple patterns the
 * legacy path can express; everything else uses the quantize-IR small ops.
 */
bool CheckNeedCompatibleQuantV2(const ge::Node& node, const QuantizeV2Config& quantizeV2Config)
{
#ifdef AI_SUPPORT_GENERAL_QUANTIZE_IR
    (void)node;
    (void)quantizeV2Config;
    return false;
#else
    if (quantizeV2Config.isOneSideQuantize) {
        FMK_LOGI("quant is oneside quant, node:%s.", node.ROLE(NodeSpec).Name().c_str());
        return true; // one-side quantization always takes the compatible path
    }
    string nodeType = node.ROLE(NodeSpec).Type();
    if (nodeType == hiai::op::ConvTranspose::TYPE || nodeType == hiai::op::GemmD::TYPE) {
        return false; // supported since V210: use the quantize-IR small ops
    }
    if (!QuantizeUtil::IsSupportQuantOpType(nodeType)) {
        return false; // all ops except the legacy quantizable ones use the IR small ops
    }
    if (quantizeV2Config.inputQuantParams.size() > 2) { // more than featuremap + filter (e.g. bias)
        return false;
    }
    for (const auto& inputQuantParam : quantizeV2Config.inputQuantParams) {
        // Exactly one inserted operator per input is required. The original only
        // rejected size > 1; size == 0 would make the operatorParams[0] accesses
        // below read out of bounds.
        if (inputQuantParam.operatorParams.size() != 1) {
            return false;
        }
    }
    if (quantizeV2Config.outputQuantParams.size() != 1) {
        return false; // multiple outputs: use the IR small ops
    }
    // Exactly one inserted operator on the output (the original only rejected
    // size > 1, leaving operatorParams[0] unguarded when size == 0), and it must
    // be a dequantize; Quant/Requant/AntiQuant on the output needs the IR ops.
    if (quantizeV2Config.outputQuantParams[0].operatorParams.size() != 1) {
        return false;
    }
    if (quantizeV2Config.outputQuantParams[0].operatorParams[0].operType != OperatorType::DEQUANTIZE) {
        return false;
    }

    uint32_t inputIndex = 0;
    uint32_t weightIndex = 1;
    GetInputAndWeightAnchorIndex(node, inputIndex, weightIndex);
    if (quantizeV2Config.inputQuantParams.size() == 1 && quantizeV2Config.inputQuantParams[0].index == weightIndex) {
        // Cascaded quantization: only the weight has quant params, meaning the
        // previous quantized op's output was requantized. Use the IR small ops.
        return false;
    }
    if (quantizeV2Config.inputQuantParams.size() != 2) {
        // Both featuremap and filter params are required below; the original
        // dereferenced inputQuantParams[1] without this guard.
        return false;
    }
    if (IsFeaturemapPerChannelQuant(quantizeV2Config.inputQuantParams, inputIndex)) {
        return false; // per-channel featuremap quantization needs the IR small ops
    }

    // The two entries may appear in either order; normalize to (input, weight).
    ge::DataType inputDataType = quantizeV2Config.inputQuantParams[0].operatorParams[0].dataType;
    ge::DataType weightDataType = quantizeV2Config.inputQuantParams[1].operatorParams[0].dataType;
    if (quantizeV2Config.inputQuantParams[0].index == weightIndex) {
        std::swap(inputDataType, weightDataType);
    }
    return IsCompatibleQuantType(inputDataType, weightDataType);
#endif
}

/*
 * Convert a legacy V1 quantize config into the V2 representation:
 * - weight: one QUANTIZE operator at the filter anchor;
 * - featuremap/output (two-side quantization only): one QUANTIZE at the
 *   featuremap anchor and one DEQUANTIZE at output 0.
 * Scale/offset vectors are swapped out of 'quantizeConfig' (it is consumed).
 * Always returns true.
 */
bool ConvertConfigV1ToV2(QuantizeConfig& quantizeConfig, const ge::Node& node, QuantizeV2Config& quantizeV2Config)
{
    QuantizeParams inputQuantPara {0, {}};
    OperatorParam weightOperParam = {0, OperatorType::QUANTIZE, quantizeConfig.weightDataType, {}, {}};
    QuantizeParams weightQuantPara {1, {std::move(weightOperParam)}};
    if (node.ROLE(NodeSpec).Type() == hiai::op::ConvTranspose::TYPE) {
        inputQuantPara.index = 2;  // per IR definition, ConvTranspose featuremap anchor index is 2
        weightQuantPara.index = 1; // per IR definition, ConvTranspose filter anchor index is 1
    }

    // A float featuremap with an int8 weight is one-side (weight-only) quantization.
    quantizeV2Config.isOneSideQuantize =
        ((quantizeConfig.inputDataType == ge::DT_FLOAT || quantizeConfig.inputDataType == ge::DT_FLOAT16) &&
        quantizeConfig.weightDataType == ge::DT_INT8);
    if (!quantizeV2Config.isOneSideQuantize) {
        OperatorParam inputOperParam = {0, OperatorType::QUANTIZE, quantizeConfig.inputDataType, {}, {}};
        // NOTE(review): a dead branch that set operIndex = 1 when
        // inputQuantPara.operatorParams was non-empty has been removed —
        // inputQuantPara is constructed empty and nothing is pushed before here,
        // so the condition could never hold.
        inputOperParam.scale.swap(quantizeConfig.inputScale);
        inputOperParam.offset.swap(quantizeConfig.inputOffset);
        inputQuantPara.operatorParams.push_back(std::move(inputOperParam));
        quantizeV2Config.inputQuantParams.push_back(std::move(inputQuantPara));

        OperatorParam outputOperParam = {0, OperatorType::DEQUANTIZE, ge::DT_INT32, {}, {}};
        QuantizeParams outputQuantPara{0, {std::move(outputOperParam)}};
        quantizeV2Config.outputQuantParams.push_back(std::move(outputQuantPara));
    }
    weightQuantPara.operatorParams[0].scale.swap(quantizeConfig.weightScale);
    weightQuantPara.operatorParams[0].offset.swap(quantizeConfig.weightOffset);
    quantizeV2Config.inputQuantParams.push_back(std::move(weightQuantPara));

    quantizeV2Config.hasQuantInfoExt = quantizeConfig.hasQuantInfoExt;
    quantizeV2Config.quantInfoExt = quantizeConfig.quantInfoExt;
    return true;
}

/*
 * Convert a V2 quantize config back into the legacy V1 representation.
 * For one-side quantization the featuremap gets a placeholder FLOAT16
 * scale/offset pair. Returns false when an input's anchor index is neither
 * the featuremap nor the weight, or when its operator list is empty.
 */
bool ConvertConfigV2ToV1(QuantizeV2Config& quantizeV2Config, const ge::Node& node, QuantizeConfig& quantizeV1Config)
{
    if (quantizeV2Config.isOneSideQuantize) {
        quantizeV1Config.inputDataType = ge::DT_FLOAT16;
        quantizeV1Config.inputScale.push_back(0.0f);
        quantizeV1Config.inputOffset.push_back(0.0f);
    }
    uint32_t inputIndex = 0;
    uint32_t weightIndex = 1;
    GetInputAndWeightAnchorIndex(node, inputIndex, weightIndex);
    // Iterate over a COPY of each element on purpose: the swap() calls below take
    // the data from the copy, leaving quantizeV2Config intact for the caller.
    for (auto inputPara : quantizeV2Config.inputQuantParams) {
        // Guard: the original dereferenced operatorParams[0] unconditionally,
        // which is out of bounds for an empty operator list.
        if (inputPara.operatorParams.empty()) {
            FMK_LOGE("Input quant operator params empty, index:%u, node:%s.", inputPara.index,
                node.ROLE(NodeSpec).Name().c_str());
            return false;
        }
        if (inputPara.index == inputIndex) {
            quantizeV1Config.inputScale.swap(inputPara.operatorParams[0].scale);
            quantizeV1Config.inputOffset.swap(inputPara.operatorParams[0].offset);
            quantizeV1Config.inputDataType = inputPara.operatorParams[0].dataType;
        } else if (inputPara.index == weightIndex) {
            quantizeV1Config.weightScale.swap(inputPara.operatorParams[0].scale);
            quantizeV1Config.weightOffset.swap(inputPara.operatorParams[0].offset);
            quantizeV1Config.weightDataType = inputPara.operatorParams[0].dataType;
        } else {
            FMK_LOGE("Input quant index is illegal:%u, node:%s.", inputPara.index, node.ROLE(NodeSpec).Name().c_str());
            return false;
        }
    }

    quantizeV1Config.hasQuantInfoExt = quantizeV2Config.hasQuantInfoExt;
    quantizeV1Config.quantInfoExt = quantizeV2Config.quantInfoExt;
    return true;
}

/*
 * Fold a scalar power-transform scale into the quant params: every input scale
 * is divided by 'transScale' and every weight scale multiplied by it.
 * Fails when 'transScale' is not strictly positive (the division would be
 * meaningless or a divide-by-near-zero).
 */
hiai::Status UpdateQuantizeConfig(
    const OpDesc& opDesc, float transScale, vector<float>& inputScale, vector<float>& weightScale)
{
    if (transScale < ZERO_EPS) {
        // Fixed log text: the old message claimed a "size" problem, but the check
        // is on the scale VALUE being too small / non-positive.
        FMK_LOGE("transScale(%f) is not positive, op name:%s", transScale, opDesc.GetName().c_str());
        return hiai::FAILED;
    }
    for (float& scale : inputScale) {
        scale /= transScale;
    }
    for (float& scale : weightScale) {
        scale *= transScale;
    }
    return hiai::SUCCESS;
}

/*
 * Multiply each weight scale by its per-channel trans scale. The two vectors
 * must have the same length (one trans scale per channel); otherwise FAILED.
 */
hiai::Status UpDateWeightScale(const OpDesc& opDesc, const vector<float>& transScale, vector<float>& weightScale)
{
    if (transScale.size() != weightScale.size()) {
        // %zu: size() returns size_t; the previous %u specifier is undefined
        // behavior for size_t arguments on LP64 targets.
        FMK_LOGE("transScale size(%zu) is not equal weight scale size(%zu), op name:%s", transScale.size(),
            weightScale.size(), opDesc.GetName().c_str());
        return hiai::FAILED;
    }
    for (size_t i = 0; i < weightScale.size(); ++i) {
        weightScale[i] *= transScale[i];
    }
    return hiai::SUCCESS;
}

// Normalize the quant parameters of 'opDesc' before saving:
// - (tiny build only) reset the featuremap scale/offset to identity for
//   convolutions carrying the ATTR_TINY_CONV_CAL attribute;
// - fold per-channel trans scales and the scalar power trans scale into the
//   scale vectors;
// - make every weight scale non-negative (the sign is extracted into the weights).
Status UpdateQuantizeInfo(
    const ge::OpDesc& opDesc, vector<float>& inputScale, vector<float>& inputOffset, vector<float>& weightScale)
{
#ifdef TINY_SUPPORT
    const string& opType = opDesc.GetType();
    const bool resetFeaturemapParams = Params::Instance()->IsTiny() &&
        opType == hiai::op::Convolution::TYPE && opDesc.HasAttr("ATTR_TINY_CONV_CAL");
    if (resetFeaturemapParams) {
        inputScale.assign(1, 1.0f);
        inputOffset.assign(1, 0.0f);
    }
#else
    (void)inputOffset;
#endif
    if (QuantizeUtil::HasTransScale(opDesc)) {
        vector<float> channelTransScale;
        if (!QuantizeUtil::GetTransScale(opDesc, channelTransScale)) {
            FMK_LOGE("Get transScale failed, op name:%s", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        if (UpDateWeightScale(opDesc, channelTransScale, weightScale) != SUCCESS) {
            FMK_LOGE("Node: %s update weightscale failed.", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
    }
    if (QuantizeUtil::HasPowerTransScale(opDesc)) {
        float powerTransScale = 0.0f;
        if (!QuantizeUtil::GetPowerTransScale(opDesc, powerTransScale)) {
            FMK_LOGE("Get transScale failed, op name:%s", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        if (UpdateQuantizeConfig(opDesc, powerTransScale, inputScale, weightScale) != SUCCESS) {
            FMK_LOGE("Node: %s update compress conf failed.", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
    }
    // Negative scales: extract the sign into the weights, keep the magnitude here.
    for (float& scale : weightScale) {
        scale = (scale > 0) ? scale : (-1.0f * scale);
    }

    return hiai::SUCCESS;
}
}

// Save legacy (V1) quant params on 'node': first normalize the scales, then
// either hand the node to the legacy fused-attribute saver (compatible cases)
// or convert the config to V2 and use the general quantize-IR saver.
Status QuantizeSaver::SaveOpQuantV1Params(
    QuantizeConfig& quantizeConfig, ge::Node* node, int64_t weightDataAddr)
{
    const auto& opDesc = node->ROLE(NodeSpec).OpDesc();
    HIAI_EXPECT_EXEC(UpdateQuantizeInfo(
        opDesc, quantizeConfig.inputScale, quantizeConfig.inputOffset, quantizeConfig.weightScale));

    if (CheckNeedCompatibleQuantV1(*node, quantizeConfig)) {
        return LegacyQuantizeSaver::SaveOpQuantParams(quantizeConfig, node, weightDataAddr);
    }

    QuantizeV2Config v2Config;
    HIAI_EXPECT_TRUE(ConvertConfigV1ToV2(quantizeConfig, *node, v2Config));
    return GeneralIRQuantizeSaver::SaveOpQuantParams(v2Config, node, weightDataAddr);
}

// Save V2 quant params on 'node': configs simple enough for the legacy
// fused-attribute format are converted to V1 and sent to the legacy saver;
// everything else goes through the general quantize-IR saver.
Status QuantizeSaver::SaveOpQuantV2Params(
    QuantizeV2Config& quantizeV2Config, ge::Node* node, int64_t weightDataAddr)
{
    const auto& nodeName = node->ROLE(NodeSpec).Name();
    if (!CheckNeedCompatibleQuantV2(*node, quantizeV2Config)) {
        FMK_LOGI(
            "QuantizeSaver::SaveOpQuantV2Params run GeneralIRQuantizeSaver, node:%s.", nodeName.c_str());
        return GeneralIRQuantizeSaver::SaveOpQuantParams(quantizeV2Config, node, weightDataAddr);
    }

    FMK_LOGI("QuantizeSaver::SaveOpQuantV2Params run LegacyQuantizeSaver, node:%s.", nodeName.c_str());
    QuantizeConfig v1Config;
    HIAI_EXPECT_TRUE(ConvertConfigV2ToV1(quantizeV2Config, *node, v1Config));
    return LegacyQuantizeSaver::SaveOpQuantParams(v1Config, node, weightDataAddr);
}
}